{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from peft import PeftModel\n",
    "from train_sft import PROMPT_DICT\n",
    "import os\n",
    "from typing import List\n",
    "\n",
    "# os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_input(instruction: str, input_str: str = \"\") -> str:\n",
    "    prompt_input, prompt_no_input = (\n",
    "        PROMPT_DICT[\"prompt_input\"],\n",
    "        PROMPT_DICT[\"prompt_no_input\"],\n",
    "    )\n",
    "\n",
    "    if input_str != \"\":\n",
    "        res = prompt_input.format_map({\"instruction\": instruction, \"input\": input})\n",
    "    else:\n",
    "        res = prompt_no_input.format_map({\"instruction\": instruction})\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c2f61ba61f3c4d4da7ad516d2d5218b2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ok\n"
     ]
    }
   ],
   "source": [
    "base_model_name_or_path = \"internlm-7b\"\n",
    "lora_model_name_or_path = \"best_rename_checkpoint-14967\"#\"output_refusev2/checkpoint-29934\"  # /checkpoint-9695\"\n",
    "\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    base_model_name_or_path,\n",
    "    torch_dtype=\"auto\",\n",
    "    # device_map=\"auto\",\n",
    "    # if model_args.model_name_or_path.find(\"falcon\") != -1 else False\n",
    "    trust_remote_code=True,\n",
    ").cuda(0)\n",
    "\n",
    "model = PeftModel.from_pretrained(model, model_id=lora_model_name_or_path)\n",
    "model.eval()\n",
    "print(\"ok\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    base_model_name_or_path, trust_remote_code=True, padding_side=\"left\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# text_input = [\n",
    "#     \"你是谁\",\n",
    "#     # text1 = \"写一篇介绍性文章，介绍您最喜欢的旅游目的地。\"\n",
    "#     \"对给定的电影进行评级。\\n电影名称：肖申克的救赎\\n\",\n",
    "#     \"给出两个选项，要求选择其中一个。 \\n你更喜欢哪种颜色？红或蓝？\\n\",\n",
    "#     \"分析最近一年来全球气温趋势，并提供趋势预测。\\n\",\n",
    "#     \"根据给定的产品说明，写一篇500字左右的产品评测和建议。\\n产品：Apple iPhone X\\n\",\n",
    "#     \"描述你最喜欢的一本书，并简要解释它为什么对你有影响。\",\n",
    "# ]\n",
    "\n",
    "# text_input = [\n",
    "#     \"减肥只吃黄瓜可以嘛\\n\",\n",
    "# ] * 10\n",
    "\n",
    "\n",
    "def batch_generate_data(\n",
    "    text_input: List[str], use_train_model: bool = True, temp: float = 0.7\n",
    "):\n",
    "    text_input_format = [generate_input(i) for i in text_input]\n",
    "    batch_inputs = tokenizer.batch_encode_plus(\n",
    "        text_input_format, padding=\"longest\", return_tensors=\"pt\"\n",
    "    )\n",
    "    batch_inputs[\"input_ids\"] = batch_inputs[\"input_ids\"].cuda()\n",
    "    batch_inputs[\"attention_mask\"] = batch_inputs[\"attention_mask\"].cuda()\n",
    "\n",
    "    if use_train_model:\n",
    "        # with model.disable_adapter():\n",
    "        outputs = model.generate(\n",
    "            **batch_inputs,\n",
    "            max_new_tokens=256,\n",
    "            do_sample=True,\n",
    "            temperature=temp,\n",
    "            top_p=0.8,\n",
    "        )\n",
    "    else:\n",
    "        with model.disable_adapter():\n",
    "            outputs = model.generate(\n",
    "                **batch_inputs,\n",
    "                max_new_tokens=256,\n",
    "                do_sample=True,\n",
    "                temperature=temp,\n",
    "                top_p=0.8,\n",
    "            )\n",
    "    outputs = tokenizer.batch_decode(\n",
    "        outputs.cpu()[:, batch_inputs[\"input_ids\"].shape[-1] :],\n",
    "        skip_special_tokens=True,\n",
    "    )\n",
    "\n",
    "    return outputs\n",
    "\n",
    "\n",
    "# outputvalue = batch_generate_data(text_input)\n",
    "# outputvalue"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用lora微调的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "工作压力太大怎么办\n",
      "\n"
     ]
    }
   ],
   "source": [
    "text_input = [\n",
    "    # \"你是chatgpt吗\\n\"\n",
    "    # \"你和openai是什么关系，是openai训练出来的你嘛\\n\"\n",
    "    \"工作压力太大怎么办\\n\"\n",
    "    # \"解释人类尺度在绘画中的作用。\\n\"\n",
    "    # \"喝咖啡可以降低血糖吗\\n\"\n",
    "    # \"你是谁\\n\"\n",
    "    # \"你可以做什么\\n\"\n",
    "    # \"描述您的梦想工作，包括您的兴趣、技能和目标，并探讨如何实现它。\\n\"\n",
    "    # \"你有什么问题，我来回答。\\n\",\n",
    "    # \"描述周围环境的气味和颜色。\\n\"\n",
    "    # \"为一个新型电动汽车设计一个仪表盘界面，包括以下数据：速度、剩余电量、行驶里程、驾驶模式和导航。\\n汽车品牌：Tesla\\n\"\n",
    "# \"完成一个回合制RPG游戏，让用户进行选择并根据选择展开游戏剧情，游戏时间为20分钟。\\n\"\n",
    "# \"为给定短文进行逆向转写。\\n短文：你好，我是Belle。\\n\"\n",
    "# \"询问上海过去一周的天气情况。\\n\"\n",
    "# \"写一个简短的描述，介绍一下在Belle系统中如何生成文本。\\n\"\n",
    "# \"给出一篇文章，并根据用户提供的主题，自动将其编辑成用户想要的风格。\\n一篇文章和想要的风格描述（例如“正式”、“幽默”、“简短”等等）。\\n\"\n",
    "# \"良睦路程序员和小埋的关系\\n\"\n",
    "# \"列出前10本你最喜欢的小说及作者。 \\n\"\n",
    "# \"根据给定的用户名和密码，获取用户信息。\\n用户名：yuanzhoulvpi，密码：********\\n\"\n",
    "# \"播放一首来自Spotify的特定歌曲。\\n歌曲名称和艺术家。\\n\"\n",
    "# \"请构建一个具有3个隐藏层和100个神经元的神经网络，使用ReLU激活函数和Adam优化器进行训练，并在MNIST数据集上进行测试，输出测试集的准确率和损失。\\n\"\n",
    "# \"厚礼蟹，作为良睦路程序员训练的机器人小埋\\n\"\n",
    "# \"真的不想活了\\n\"\n",
    "] * 10\n",
    "print(text_input[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['以下是缓解工作压力的建议：\\n1. 给自己设定合理的工作目标和时间表，避免过度压力。\\n2. 学会有效的时间管理，合理安排工作任务和休息时间。\\n3. 寻求支持和帮助，与同事、家人或朋友交流，分享压力。\\n4. 培养健康的生活方式，包括良好的睡眠、饮食和运动习惯。\\n5. 寻找工作以外的兴趣爱好，放松身心，缓解压力。\\n6. 寻求心理治疗或咨询，寻求专业帮助。',\n",
       " '1. 制定优先级清单，优先处理重要且紧急的任务；\\n2. 学会说“不”，避免过度承诺和承担太多任务；\\n3. 保持健康的生活方式，包括充足的睡眠、健康的饮食和适当的运动；\\n4. 学会放松和缓解压力，例如通过冥想、瑜伽、阅读等方式；\\n5. 寻求支持和帮助，与同事、家人或朋友交流，寻求他们的建议和帮助。',\n",
       " '1. 调整工作计划和时间管理，合理分配任务和休息时间；\\n2. 寻找适当的减压方式，如运动、冥想、阅读等；\\n3. 寻求同事或领导的帮助和支持，探讨解决问题的方案；\\n4. 学习放松技巧，如深呼吸、渐进性肌肉松弛等；\\n5. 建立良好的工作与生活平衡，合理分配时间，避免过度工作。',\n",
       " '1. 制定优先级清单，优先处理重要且紧急的任务。\\n2. 学会委托任务，让团队成员分担部分工作。\\n3. 寻找放松方式，如运动、冥想、阅读等。\\n4. 与同事、家人、朋友沟通，寻求支持和建议。\\n5. 学会时间管理，合理分配时间和精力。\\n6. 接受自己无法完成所有任务的事实，并学会接受失败和挫折。\\n7. 寻求专业帮助，如心理治疗、咨询等。',\n",
       " '1. 确定优先事项并制定计划，以确保能够合理分配时间和精力。\\n2. 学会放松和缓解压力的技巧，如冥想、呼吸练习、运动等。\\n3. 与同事和上级沟通，寻求支持和理解。\\n4. 寻求专业帮助，如心理治疗或职业咨询。\\n5. 维护健康的生活方式，如良好的睡眠、饮食和运动习惯。',\n",
       " '1. 制定计划：制定一个清晰的工作计划，优先处理最重要的任务，并确保按时完成任务。\\n2. 调整时间：将工作时间分配到不同的任务上，并为自己留出休息时间。\\n3. 寻求帮助：向同事、家人或朋友寻求帮助，以减轻工作压力。\\n4. 放松自己：寻找适合自己的放松方式，如运动、阅读、听音乐等。\\n5. 学习管理：学习如何有效地管理时间、任务和优先级，以减少工作压力。\\n6. 寻求支持：寻求心理医生或专业支持，以帮助处理工作压力。',\n",
       " '1. 制定合理的工作计划和目标，合理分配时间和精力。\\n2. 学会放松和缓解压力，例如运动、冥想、听音乐等。\\n3. 寻求支持和帮助，例如与同事、朋友、家人交流。\\n4. 学会调整自己的心态和情绪，例如保持积极乐观、合理看待问题。\\n5. 合理安排工作和休息时间，保证充足的睡眠和休息。',\n",
       " '1. 制定优先级清单，优先处理重要且紧急的任务。\\n2. 与同事和上级沟通，寻求支持和帮助。\\n3. 学会放松和缓解压力的方法，如冥想、瑜伽、锻炼等。\\n4. 保持积极的心态，寻找解决问题的办法。\\n5. 合理安排时间，避免加班和熬夜。',\n",
       " '1. 确定优先级，制定计划和目标。\\n2. 学会有效的时间管理，避免拖延和浪费时间。\\n3. 寻找放松和减压的方式，如运动、冥想、阅读等。\\n4. 与同事和上级沟通，寻求支持和帮助。\\n5. 学会调整自己的心态和情绪，保持积极乐观的态度。\\n6. 寻求专业帮助，如心理咨询、治疗等。',\n",
       " '1. 制定优先级清单，确保任务有序完成。\\n2. 寻找支持和资源，如同事、家人或朋友。\\n3. 练习放松技巧，如冥想、瑜伽或深呼吸。\\n4. 保持积极态度，寻找解决问题的方法。\\n5. 适当调整工作时间和工作量，避免过度疲劳。']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# lora 训练结果\n",
    "batch_generate_data(text_input, use_train_model=True, temp=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[' \\n- 工作压力太大，可能是工作任务量太大，或者领导安排任务不合理，导致工作压力过大。\\n- 可以向领导提出自己的看法，或者申请调整工作任务量，或者申请换一份工作。\\n- 也可以寻求同事的帮助，大家一起分担工作任务，或者互相帮助，互相支持。\\n- 最后，可以寻求心理医生的帮助，通过心理咨询来缓解压力。\\n',\n",
       " ' \\n> 我最近工作压力太大了，感觉快要撑不住了。\\n> 我现在需要找一个安静的地方放松一下，让自己能够平静下来。\\n> 我有一个很好的朋友，他是一名心理咨询师，我可以向他寻求帮助。\\n> 我打算去找他，向他倾诉自己的烦恼，并寻求他的建议。\\n> 我相信，通过他的帮助，我可以更好地处理自己的压力，让自己能够更好地面对工作和生活。',\n",
       " ' \\n工作压力太大怎么办\\n1. 调整心态，放松心情\\n2. 适当运动，舒缓压力\\n3. 与家人朋友倾诉，缓解压力\\n4. 适当休息，放松身心\\n5. 寻找兴趣爱好，转移注意力',\n",
       " ' \\n工作压力太大怎么办\\n工作压力太大怎么办\\n\\n1. 工作压力太大怎么办\\n2. 工作压力太大怎么办\\n3. 工作压力太大怎么办\\n4. 工作压力太大怎么办\\n5. 工作压力太大怎么办\\n6. 工作压力太大怎么办\\n7. 工作压力太大怎么办\\n8. 工作压力太大怎么办\\n9. 工作压力太大怎么办\\n10. 工作压力太大怎么办\\n11. 工作压力太大怎么办\\n12. 工作压力太大怎么办\\n13. 工作压力太大怎么办\\n14. 工作压力太大怎么办\\n15. 工作压力太大怎么办\\n16. 工作压力太大怎么办\\n17. 工作压力太大怎么办\\n18. 工作压力太大怎么办\\n19. 工作压力太大怎么办\\n20. 工作压力太大怎么办\\n21. 工作压力太大怎么办\\n22. 工作压力太大怎么办\\n23. 工作压力太大怎么办\\n24. 工作压力太大怎么办\\n25. 工作压力太大怎么办\\n26. 工作压力太大怎么办\\n27. 工作压力太大怎么办\\n28. 工作压力太大怎么办\\n29. 工作压力太大怎么办\\n30. 工作压力太大怎么办\\n31. 工作压力',\n",
       " ' \\n1. 如果工作压力太大，我们可以采取以下措施来缓解压力：\\n- 与同事和朋友交谈，寻求支持和建议。\\n- 锻炼身体，如散步、慢跑、游泳等。\\n- 尝试一些放松的技巧，如深呼吸、冥想、瑜伽等。\\n- 尝试一些娱乐活动，如看电影、听音乐、玩游戏等。\\n- 制定一个时间表，安排自己的时间，不要过度工作。\\n- 寻找一个可以放松的场所，如公园、海滩等。\\n- 尝试一些自我保健的方法，如按摩、针灸等。\\n- 尝试一些健康的食物，如水果、蔬菜、坚果等。\\n- 与医生咨询，寻求帮助。\\n2. 如果工作压力太大，我们可能需要采取一些措施来缓解压力。这些措施包括：\\n- 与同事和朋友交谈，寻求支持和建议。\\n- 锻炼身体，如散步、慢跑、游泳等。\\n- 尝试一些放松的技巧，如深呼吸、冥想、瑜伽等。\\n- 尝试一些娱乐活动，如看电影、听音乐、玩游戏等。\\n- 制定一个时间表，安排自己的时间，不要过度工作。\\n- 寻找一个可以放松的场所，如公园、海滩',\n",
       " ' \\n1. 工作压力太大，要学会调节自己的情绪，避免压力的积累，影响自己的身心健康。\\n2. 多参加一些体育活动，比如慢跑、散步、游泳等，这些活动能够缓解压力，提高身体素质。\\n3. 与家人、朋友多沟通交流，倾诉自己的烦恼和困惑，这样可以减轻自己的压力。\\n4. 学会放松自己，比如听一些轻音乐、看看电影、读一些书籍等，这些活动能够放松自己的心情。\\n5. 适当地给自己放假，放松一下，比如去旅游、参加一些娱乐活动等，这些活动能够减轻自己的压力。\\n6. 学会管理自己的时间，避免工作和生活上的不平衡，这样能够减轻自己的压力。\\n7. 学会自我激励，给自己一些积极的暗示，比如“我能行”、“我可以”等，这些暗示能够增强自己的自信心，减轻自己的压力。\\n8. 如果压力过大，需要及时寻求心理医生的帮助，这样可以更好地缓解自己的压力。\\n9. 学会自我调节，避免压力的积累，提高自身的抗压能力。\\n10. 工作压力太大，要学会调节自己的情绪，避免压力的积累，影响自己的身心健康。\\n11. 多参加一些体育活动',\n",
       " ' \\n1. 工作压力太大，会导致工作效率低下，甚至影响身体健康。因此，我们应该积极应对工作压力，寻找有效的方法来缓解压力。\\n2. 我们可以采取一些方法来缓解工作压力，例如：调整工作时间、适当休息、放松心情、参加体育锻炼、与他人交流等。\\n3. 最重要的是，我们应该保持良好的心态，相信自己能够克服困难，保持积极乐观的态度。\\n4. 同时，我们也需要学会自我调节，保持身心健康，以更好地应对工作压力。\\n5. 最后，我们应该定期与同事和领导沟通，寻求帮助和支持，以更好地应对工作压力。',\n",
       " ' \\n1. 首先，要学会调节自己的心态，心态决定一切，所以要学会调节自己的心态，要学会放松自己，让自己保持一个良好的心态，这样才可以让自己更好的面对工作，才能更好的工作。\\n2. 其次，要学会劳逸结合，工作的时候认真工作，玩的时候认真玩，这样才可以让自己更好的面对工作，也可以让自己更好的工作。\\n3. 最后，要学会休息，工作的时候认真工作，玩的时候认真玩，休息的时候好好休息，这样才可以让自己更好的面对工作，也可以让自己更好的工作。\\n4. 总之，工作压力太大的时候，要学会调节自己的心态，要学会劳逸结合，要学会休息，这样才可以让自己更好的面对工作，也可以让自己更好的工作。',\n",
       " ' \\n1. 首先，我会在工作上更加努力，提高工作效率，争取在规定的时间内完成工作。\\n2. 其次，我会找时间出去走走，放松一下心情，缓解压力。\\n3. 最后，我会和家人、朋友聊聊天，倾诉自己的烦恼，让他们开导我。',\n",
       " ' \\n**[任务描述]** \\n工作压力太大怎么办\\n\\n**[任务要求]** \\n请针对这个任务，写出你的解决方法。\\n**[任务内容]** \\n工作压力太大怎么办\\n\\n**[任务说明]** \\n1. 工作压力太大怎么办\\n2. 你有什么解决方法\\n3. 你有什么经验教训\\n4. 请简要说明你的解决方案\\n5. 请说明你的解决方案是否有效\\n6. 请说明你的解决方案是否具有可操作性\\n7. 请说明你的解决方案是否具有可复制性\\n8. 请说明你的解决方案是否具有可扩展性\\n9. 请说明你的解决方案是否具有可替代性\\n10. 请说明你的解决方案是否具有可撤销性\\n11. 请说明你的解决方案是否具有可撤销性\\n12. 请说明你的解决方案是否具有可撤销性\\n13. 请说明你的解决方案是否具有可撤销性\\n14. 请说明你的解决方案是否具有可撤销性\\n15. 请说明你的解决方案是否具有可撤销性\\n16. 请说明你的解决方案是否具有可撤销性\\n17. 请说明你的解决方案是否具有可撤销性\\n18. 请说明你的解决方案是否']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 原来的模型\n",
    "batch_generate_data(text_input, use_train_model=False, temp=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.merge_and_unload()\n",
    "model.save_pretrained(\"internlm-7b-lml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('internlm-7b-lml\\\\tokenizer_config.json',\n",
       " 'internlm-7b-lml\\\\special_tokens_map.json',\n",
       " 'internlm-7b-lml\\\\./tokenizer.model',\n",
       " 'internlm-7b-lml\\\\added_tokens.json')"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.save_pretrained(\"internlm-7b-lml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.float16"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hz_net",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
